{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(numba_for_arviz)=\n",
    "# Numba - an overview"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Numba](https://numba.pydata.org/numba-doc/latest/index.html) is a just-in-time compiler for Python that works best on code that uses NumPy arrays, NumPy functions, and loops.\n",
    "ArviZ includes {ref}`Numba as an optional dependency <Optional-dependencies>`, and `utils.py` provides a number of helper functions that take advantage of it on systems where Numba is installed. The {class}`arviz.Numba` class additionally lets you disable and re-enable Numba on those systems."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A simple example to display the effectiveness of Numba"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import arviz as az\n",
    "import numpy as np\n",
    "import timeit\n",
    "\n",
    "from arviz.utils import conditional_jit, Numba\n",
    "from arviz.stats.diagnostics import ks_summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = np.random.randn(1000000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def variance(data, ddof=0):  # Pure-Python variance, computed without Numba\n",
    "    a_a, b_b = 0, 0\n",
    "    for i in data:\n",
    "        a_a = a_a + i\n",
    "        b_b = b_b + i * i\n",
    "    var = b_b / (len(data)) - ((a_a / (len(data))) ** 2)\n",
    "    var = var * (len(data) / (len(data) - ddof))\n",
    "    return var"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "140 ms ± 2.59 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit variance(data, ddof=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "@conditional_jit\n",
    "def variance_jit(data, ddof=0):  # Same computation, JIT-compiled by Numba when it is installed\n",
    "    a_a, b_b = 0, 0\n",
    "    for i in data:\n",
    "        a_a = a_a + i\n",
    "        b_b = b_b + i * i\n",
    "    var = b_b / (len(data)) - ((a_a / (len(data))) ** 2)\n",
    "    var = var * (len(data) / (len(data) - ddof))\n",
    "    return var"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.03 ms ± 44.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
     ]
    }
   ],
   "source": [
    "%timeit variance_jit(data, ddof=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That is roughly 135 times faster (140 ms versus 1.03 ms). Let's compare this to NumPy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.79 ms ± 124 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit np.var(data, ddof=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here the Numba-compiled loop (1.03 ms) is also faster than `np.var` (1.79 ms): in certain scenarios, Numba can even outperform NumPy."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Numba within ArviZ\n",
    "Let's see Numba's effect on a few ArviZ functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "summary_data = np.random.randn(1000, 100, 10)\n",
    "school = az.load_arviz_data(\"centered_eight\").posterior[\"mu\"].values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The {class}`~arviz.Numba` class provides the methods `disable_numba` and `enable_numba` to turn Numba off and on. Its `numba_flag` attribute indicates whether Numba is currently enabled within ArviZ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Numba.disable_numba()\n",
    "Numba.numba_flag"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "57.8 ms ± 1.02 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit ks_summary(summary_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "462 µs ± 16.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit ks_summary(school)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Numba.enable_numba()\n",
    "Numba.numba_flag"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7.18 ms ± 359 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
     ]
    }
   ],
   "source": [
    "%timeit ks_summary(summary_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "359 µs ± 62.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
     ]
    }
   ],
   "source": [
    "%timeit ks_summary(school)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
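    "Numba has provided a speedup once again: roughly 8 times faster on the larger `summary_data` array (57.8 ms versus 7.18 ms), and a more modest gain of about 1.3 times on the smaller `school` array (462 µs versus 359 µs)."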
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
